{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DataPipe Typing System\n",
    "\n",
    "DataPipe typing system is introduced to make the graph of DataPipes more reliable and provide type inference for users. The typing system provide the flexibility for users to determine which level(s) to have type enforcement and risk false positive errors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import IterDataPipe\n",
    "from typing import Any, Iterator, List, Tuple, TypeVar, Set, Union\n",
    "\n",
    "T_co = TypeVar('T_co', covariant=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hide traceback of Error\n",
    "import functools\n",
    "ipython = get_ipython()\n",
    "def showtraceback(self, exc_tuple=None, filename=None, tb_offset=None,\n",
    "                  exception_only=False, running_compiled_code=False):\n",
    "    try:\n",
    "        try:\n",
    "            etype, value, tb = self._get_exc_info(exc_tuple)\n",
    "        except ValueError:\n",
    "            print('No traceback available to show.', file=sys.stderr)\n",
    "            return\n",
    "\n",
    "        # Hide traceback\n",
    "        stb = self.InteractiveTB.get_exception_only(etype, value)\n",
    "\n",
    "        self._showtraceback(etype, value, stb)\n",
    "\n",
    "    except KeyboardInterrupt:\n",
    "        print('\\n' + self.get_exception_only(), file=sys.stderr)\n",
    "ipython.showtraceback = functools.partial(showtraceback, ipython)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile-time\n",
    "Compile-time typing is enabled by default for now. And it will generate an attribute of `type` for each DataPipe. If there is no type hint specified, the DataPipe is set to a default type `Any`.\n",
    "\n",
    "### Invalid Typing\n",
    "- Return type hint of `__iter__` is not `Iterator`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Expected 'Iterator' as the return annotation for `__iter__` of InvalidDP1, but found str",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Expected 'Iterator' as the return annotation for `__iter__` of InvalidDP1, but found str\n"
     ]
    }
   ],
   "source": [
    "class InvalidDP1(IterDataPipe[int]):\n",
    "    def __iter__(self) -> str:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Return type hint of `__iter__` doesn't match or is subtype of the declared type hint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Expected return type of '__iter__' as a subtype of int, but found str for InvalidDP2",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Expected return type of '__iter__' as a subtype of int, but found str for InvalidDP2\n"
     ]
    }
   ],
   "source": [
    "class InvalidDP2(IterDataPipe[int]):\n",
    "    def __iter__(self) -> Iterator[str]:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Valid Typing\n",
    "\n",
    "- It's allowed that return type is a subtype of class type annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DP(IterDataPipe[Tuple]):\n",
    "    def __iter__(self) -> Iterator[Tuple[int, str]]:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DP(IterDataPipe):\n",
    "    def __iter__(self) -> Iterator[int]:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Default Typing (Any) with/without return hint for `__iter__`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "typing.Any\n",
      "typing.Any\n",
      "typing.Any\n"
     ]
    }
   ],
   "source": [
    "class DP(IterDataPipe):\n",
    "    def __iter__(self):\n",
    "        pass\n",
    "print(DP.type)\n",
    "class DP(IterDataPipe):\n",
    "    def __iter__(self) -> Iterator:\n",
    "        pass\n",
    "print(DP.type)\n",
    "class DP(IterDataPipe):\n",
    "    def __iter__(self) -> Iterator[T_co]:\n",
    "        pass\n",
    "print(DP.type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Matched type hints (including equal but not same types)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "typing.Tuple[+T_co, str]\n",
      "typing.Tuple[~T, str]\n"
     ]
    }
   ],
   "source": [
    "class DP(IterDataPipe[Tuple[T_co, str]]):\n",
    "    def __iter__(self) -> Iterator[Tuple[T_co, str]]:\n",
    "        pass\n",
    "print(DP.type)\n",
    "\n",
    "T = TypeVar('T', int, str)  # equals to Union[int, str]\n",
    "class DP(IterDataPipe[Tuple[T, str]]):\n",
    "    def __iter__(self) -> Iterator[Tuple[Union[int, str], str]]:\n",
    "        pass\n",
    "print(DP.type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attribute `type`\n",
    "The attribute `type` is added into each DataPipe class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_helper(cls, obj):\n",
    "    print(\"DataPipe[{}]\\nInstance type: {}\"\n",
    "          .format(cls.type, obj.type))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataPipe[typing.List[int]]\n",
      "Instance type: typing.List[int]\n"
     ]
    }
   ],
   "source": [
    "class DP(IterDataPipe[List[int]]):\n",
    "    def __iter__(self) -> Iterator[List[int]]:\n",
    "        pass\n",
    "print_helper(DP, DP())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataPipe[typing.Any]\n",
      "Instance type: typing.Any\n"
     ]
    }
   ],
   "source": [
    "class DP(IterDataPipe[Any]):\n",
    "    def __iter__(self) -> Iterator[Any]:\n",
    "        pass\n",
    "print_helper(DP, DP())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataPipe[tuple]\n",
      "Instance type: tuple\n"
     ]
    }
   ],
   "source": [
    "class DP(IterDataPipe[tuple]):\n",
    "    def __iter__(self) -> Iterator[tuple]:\n",
    "        pass\n",
    "print_helper(DP, DP())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct-time\n",
    "\n",
    "Construct-time type checking can be enabled by a decorator `argument_validation`. Users can opt in by attaching the decorator to `__init__`function, then users can run operations with the type inference of input `DataPipe`(s)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import argument_validation\n",
    "\n",
    "class DP(IterDataPipe):\n",
    "    @argument_validation\n",
    "    def __init__(self, dp: IterDataPipe[Union[int, tuple]]):\n",
    "        self.dp = dp\n",
    "\n",
    "    def __iter__(self):\n",
    "        for d in self.dp:\n",
    "            yield d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Expected argument 'dp' as a IterDataPipe, but found <class 'range'>",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Expected argument 'dp' as a IterDataPipe, but found <class 'range'>\n"
     ]
    }
   ],
   "source": [
    "dp = DP(range(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- When any input is annotated by `IterDataPipe` with detail typing hints, the `type` of input instance must be a subtype of the hint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Expected type of argument 'dp' as a subtype of hint typing.Union[int, tuple], but found str",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Expected type of argument 'dp' as a subtype of hint typing.Union[int, tuple], but found str\n"
     ]
    }
   ],
   "source": [
    "class Temp(IterDataPipe[str]):\n",
    "    def __iter__(self):\n",
    "        pass\n",
    "dp = DP(Temp())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Example of valid input `DataPipe`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Temp(IterDataPipe[Tuple[int, T_co]]):\n",
    "    def __iter__(self):\n",
    "        pass\n",
    "dp = DP(Temp())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Runtime\n",
    "\n",
    "\n",
    "### Decorator\n",
    "Runtime type checking is enabled by a decorator `runtime_validation`. Users can opt in by attaching the decorator to `__iter__` to check the output data is an instance of subtype of `type` attribute of the DataPipe.\n",
    "\n",
    "Note: This decorator is only allowed to be attached to `__iter__` for now. It can be extended into `__getitem__` and further `nonblocking` functions.\n",
    "\n",
    "`runtime_validation_disabled` is a context manager to turn off the type validaiton during runtime. It's useful for DataLoader to disable the runtime validaiton after the first epoch is finished for better performance. Note: the runtime validation is enabled by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import runtime_validation, runtime_validation_disabled\n",
    "\n",
    "class DP(IterDataPipe[Tuple[int, T_co]]):\n",
    "    def __init__(self, datasource):\n",
    "        self.ds = datasource\n",
    "        \n",
    "    @runtime_validation\n",
    "    def __iter__(self):\n",
    "        for d in self.ds:\n",
    "            yield d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Raise `RuntimeError` when the data is not of subtype\n",
    "\n",
    "- `str` is not subtype of `int`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1)\n",
      "(2, 2)\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Expected an instance as subtype of typing.Tuple[int, +T_co], but found ('3', 3)(<class 'tuple'>)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mRuntimeError\u001b[0m\u001b[0;31m:\u001b[0m Expected an instance as subtype of typing.Tuple[int, +T_co], but found ('3', 3)(<class 'tuple'>)\n"
     ]
    }
   ],
   "source": [
    "dp = DP([(1, 1), (2, 2), ('3', 3)])\n",
    "for d in dp:\n",
    "    print(d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " - Context manager to disable the runtime validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 1), (2, 2), ('3', 3)]\n"
     ]
    }
   ],
   "source": [
    "with runtime_validation_disabled():\n",
    "    print(list(dp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- `List` is not subtype of `Tuple`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1)\n",
      "(2, 2)\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Expected an instance as subtype of typing.Tuple[int, +T_co], but found [3, 3](<class 'list'>)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mRuntimeError\u001b[0m\u001b[0;31m:\u001b[0m Expected an instance as subtype of typing.Tuple[int, +T_co], but found [3, 3](<class 'list'>)\n"
     ]
    }
   ],
   "source": [
    "dp = DP([(1, 1), (2, 2), [3, 3]])\n",
    "for d in dp:\n",
    "    print(d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Context manager to disable the runtime validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 1), (2, 2), [3, 3]]\n"
     ]
    }
   ],
   "source": [
    "with runtime_validation_disabled():\n",
    "    print(list(dp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- No error will be raised when all data pass the validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 1), (2, '2'), (3, 3.0)]\n"
     ]
    }
   ],
   "source": [
    "dp = DP([(1, 1), (2, '2'), (3, 3.)])\n",
    "print(list(dp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reinforce type for DataPipe instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "T = TypeVar('T', int, str)\n",
    "ds = list(range(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- If the DataPipe class is not decorated with `runtime_validation` and the DataPipe instance calls `reinforce_type`, a warning will be raised."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/erjia/workspace/pytorch-dev/typing/torch/utils/data/_typing.py:346: UserWarning: The type of data generated from `DataPipe` instance won't be validated at runtime. Decorator of `runtime_validation` is required to be attached to `__iter__` method of <class '__main__.DP'> for runtime type validation\n",
      "  warnings.warn(\"The type of data generated from `DataPipe` instance won't be validated \"\n"
     ]
    }
   ],
   "source": [
    "class DP(IterDataPipe[T]):\n",
    "    def __init__(self, ds):\n",
    "        self.ds = ds\n",
    "        \n",
    "    def __iter__(self):\n",
    "        for d in self.ds:\n",
    "            yield d\n",
    "dp = DP(ds).reinforce_type(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DP(IterDataPipe[T]):\n",
    "    def __init__(self, ds):\n",
    "        self.ds = ds\n",
    "        \n",
    "    @runtime_validation\n",
    "    def __iter__(self):\n",
    "        for d in self.ds:\n",
    "            yield d"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- expected type must be a subtype of the original type hint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "Expected 'expected_type' as a subtype of ~T, but found float",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mTypeError\u001b[0m\u001b[0;31m:\u001b[0m Expected 'expected_type' as a subtype of ~T, but found float\n"
     ]
    }
   ],
   "source": [
    "dp = DP(ds).reinforce_type(float)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Integer data is not subtype of str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Expected an instance as subtype of str, but found 0(<class 'int'>)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31mRuntimeError\u001b[0m\u001b[0;31m:\u001b[0m Expected an instance as subtype of str, but found 0(<class 'int'>)\n"
     ]
    }
   ],
   "source": [
    "dp = DP(ds).reinforce_type(str)\n",
    "list(dp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Compatible with context mangager to disable validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n"
     ]
    }
   ],
   "source": [
    "with runtime_validation_disabled():\n",
    "    print(list(dp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Valid type enforcement"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n"
     ]
    }
   ],
   "source": [
    "dp = DP(ds).reinforce_type(int)\n",
    "print(list(dp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Different type based on the logic of class initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DP(IterDataPipe[Union[int, str]]):\n",
    "    def __init__(self, label):\n",
    "        if label == 'int':\n",
    "            self.reinforce_type(int)\n",
    "        elif label == 'str':\n",
    "            self.reinforce_type(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "int\n"
     ]
    }
   ],
   "source": [
    "dp = DP('int')\n",
    "print(dp.type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "str\n"
     ]
    }
   ],
   "source": [
    "dp = DP('str')\n",
    "print(dp.type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "typing.Union[int, str]\n"
     ]
    }
   ],
   "source": [
    "dp = DP('')\n",
    "print(dp.type)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
